# Copyright 2019 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Exploitability measurement utils for general (sym and non-sym) games."""

from absl import logging  # pylint:disable=unused-import

import numpy as np
from scipy import special

from open_spiel.python.algorithms.adidas_utils.helpers import misc


def unreg_exploitability(dist, payoff_tensor, aggregate=np.mean):
  """Measure how far dist is from a Nash equilibrium, without regularization.

  For each player, the payoff of a pure best response is compared against the
  payoff of that player's own mixed strategy in dist, with every other player
  held fixed at dist. The per-player gaps are combined with `aggregate`.

  Args:
    dist: list of 1-d np.arrays, current estimate of nash distribution
    payoff_tensor: (n x A1 x ... x An) np.array, payoffs for each joint action
      can also be list of (A1 x ... x An) np.arrays
    aggregate: function to reduce individual exp_is to scalar, e.g., mean or max
  Returns:
    exploitability (float): aggregate over players of
      (payoff_i of best response_i - payoff_i of dist)
  """

  def _gap(player):
    # Assumed: pt_reduce leaves only `player`'s own action axis, giving the
    # payoff of each of that player's actions against the others' dist.
    grad = misc.pt_reduce(payoff_tensor[player], dist, [player])
    best_value = np.max(grad)  # value of the best pure action
    current_value = grad.dot(dist[player])  # value of the current mixture
    return best_value - current_value

  gaps = [_gap(player) for player in range(len(payoff_tensor))]
  return aggregate(gaps)


def ate_exploitability(dist, payoff_tensor, p=1, aggregate=np.mean):
  """Compute Tsallis regularized exploitability of dist for non-symmetric game.

  Each player's utility is augmented with the Tsallis term
  s / (p + 1) * (1 - sum_a x_a^(p + 1)), where s is the (1/p)-norm of that
  player's payoff gradient nabla_i (the inf-norm when p == 0). For p > 0 the
  regularized best response is (nabla_i / s)^(1/p). For p == 0 it is uniform
  over the maximizing actions.

  Args:
    dist: list of 1-d np.arrays, current estimate of nash distribution
    payoff_tensor: (n x A1 x ... x An) np.array, payoffs for each joint action
      assumed to be non-negative. can also be list of (A1 x ... x An) np.arrays
    p: float in [0, 1], Tsallis entropy-regularization --> 0 as p --> 0
    aggregate: function to reduce individual exp_is to scalar, e.g., mean or max
  Returns:
    exploitability (float): aggregate over players of
      (regularized payoff_i of best response_i - regularized payoff_i of dist)
  Raises:
    ValueError: if payoff_tensor has a negative entry.
  """
  if np.min(payoff_tensor) < 0.:
    raise ValueError('payoff tensor must be non-negative')
  num_players = len(payoff_tensor)

  exp_i = []
  for i in range(num_players):
    nabla_i = misc.pt_reduce(payoff_tensor[i], dist, [i])
    dist_i = dist[i]
    if p > 0:
      power = 1. / p
      s = np.linalg.norm(nabla_i, ord=power)
      if s > 0:
        # Closed-form regularized best response. With non-negative nabla_i,
        # its entries sum to 1 because s is the (1/p)-norm of nabla_i.
        br_i = (nabla_i / s)**power
      else:
        # nabla_i is identically zero, so every strategy earns 0 and the
        # regularizer (scaled by s == 0) vanishes; any distribution is a best
        # response. Use uniform instead of dividing 0 / 0 (which yields NaN).
        br_i = np.ones_like(dist_i, dtype=float) / len(dist_i)
    else:
      power = np.inf
      # With non-negative payoffs the inf-norm equals max(nabla_i).
      s = np.linalg.norm(nabla_i, ord=power)
      br_i = np.zeros_like(dist_i)
      maxima = (nabla_i == s)
      br_i[maxima] = 1. / maxima.sum()

    u_i_br = nabla_i.dot(br_i) + s / (p + 1) * (1 - np.sum(br_i**(p + 1)))
    u_i_dist = nabla_i.dot(dist_i) + s / (p + 1) * (1 - np.sum(dist_i**(p + 1)))

    exp_i.append(u_i_br - u_i_dist)

  return aggregate(exp_i)


def qre_exploitability(dist, payoff_tensor, temperature=0., aggregate=np.mean):
  """Measure Shannon-entropy regularized exploitability of dist (non-sym game).

  Each player's utility is the expected payoff plus temperature times the
  Shannon entropy of their strategy. The reported per-player gap is the
  regularized utility of the regularized best response minus that of the
  player's strategy in dist. The gaps are combined with `aggregate`.

  Args:
    dist: list of 1-d np.arrays, current estimate of nash distribution
    payoff_tensor: (n x A1 x ... x An) np.array, payoffs for each joint action
      assumed to be non-negative (not checked here). can also be list of
      (A1 x ... x An) np.arrays
    temperature: non-negative float
    aggregate: function to reduce individual exp_is to scalar, e.g., mean or max
  Returns:
    exploitability (float): aggregate over players of
      (regularized payoff_i of best response_i - regularized payoff_i of dist)
  """

  def _regularized_value(grad, strategy):
    # special.entr is elementwise -x * log(x), so its sum is Shannon entropy.
    return grad.dot(strategy) + temperature * special.entr(strategy).sum()

  gaps = []
  for player in range(len(payoff_tensor)):
    grad = misc.pt_reduce(payoff_tensor[player], dist, [player])
    own = dist[player]
    if temperature > 0:
      # Logit (softmax) response at the given temperature.
      response = special.softmax(grad / temperature)
    else:
      # Zero temperature: spread mass evenly over every maximizing action.
      is_best = grad == np.max(grad)
      response = np.zeros_like(own)
      response[is_best] = 1. / is_best.sum()

    gaps.append(_regularized_value(grad, response) -
                _regularized_value(grad, own))

  return aggregate(gaps)
